# Spectral GCN + Attention Recovery + LSTM
# This code trains and tests the GNN model for the COVID-19 infection prediction in Tokyo
# Author: Jiawei Xue, August 26, 2021
# Step 1: read and pack the traning and testing data
# Step 2: training epoch, training process, testing
# Step 3: build the model = spectral GCN + Attention Recovery + LSTM
# Step 4: main function
# Step 5: evaluation
# Step 6: visualization
import os
import csv
import json
import copy
import time
import random
import string
import argparse
import numpy as np
import pandas as pd
import geopandas as gpd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import matplotlib.pyplot as plt
from matplotlib import pyplot as plt
import torch.nn.functional as F
from Memory_spectral_T3_GCN_memory_light_1_layer import SpecGCN
from Memory_spectral_T3_GCN_memory_light_1_layer import SpecGCN_LSTM
#torch.set_printoptions(precision=8)
#hyperparameter for the setting
X_day, Y_day = 21,21
#START_DATE, END_DATE = '20200414','20210207'
#START_DATE, END_DATE = '20200808','20210603'
START_DATE, END_DATE = '20200720','20210515'
WINDOW_SIZE = 7
#hyperparameter for the learning
DROPOUT, ALPHA = 0.50, 0.20
NUM_EPOCHS, BATCH_SIZE, LEARNING_RATE = 100, 8, 0.0001
HIDDEN_DIM_1, OUT_DIM_1, HIDDEN_DIM_2 = 6,4,3
infection_normalize_ratio = 100.0
web_search_normalize_ratio = 100.0
train_ratio = 0.7
validate_ratio = 0.1
SEED = 5
torch.manual_seed(SEED)
r = random.random
random.seed(5)
#1.total period (mobility+text):
#from 20200201 to 20210620: (29+31+30+31+30+31+31+30+31+30+31)+(31+28+31+30+31+20)\
#= 335 + 171 = 506;
#2.number of zones: 23;
#3.infection period:
#20200331 to 20210620: (1+30+31+30+31+31+30+31+30+31)+(31+28+31+30+31+20) = 276 + 171 = 447.
#1. Mobility: functions 1.2 to 1.7
#2. Text: functions 1.8 to 1.14
#3. InfectionL: functions 1.15
#4. Preprocess: functions 1.16 to 1.24
#5. Learn: functions 1.25 to 1.26
#function 1.1
#get the central areas of Tokyo (e.g., the Special wards of Tokyo)
#return: a 23 zone shapefile
def read_tokyo_23():
folder = "/data/HSEES/xue/xue_codes/disease_prediction_ml/gml_code/present_model_version10/tokyo_23"
file = "tokyo_23zones.shp"
path = os.path.join(folder,file)
data = gpd.read_file(path)
return data
##################1.Mobility#####################
#function 1.2
#get the average of two days' mobility (infection) records
def mob_inf_average(data, key1, key2):
new_record = dict()
record1, record2 = data[key1], data[key2]
for i in record1:
if i in record2:
new_record[i] = (record1[i]+record2[i])/2.0
return new_record
#function 1.3
#get the average of multiple days' mobility (infection) records
def mob_inf_average_multiple(data, keyList):
new_record = dict()
num_day = len(keyList)
for i in range(num_day):
record = data[keyList[i]]
for zone_id in record:
if zone_id not in list(new_record.keys()):
new_record[zone_id] = record[zone_id]
else:
new_record[zone_id] += record[zone_id]
for new_record_key in new_record:
new_record[new_record_key] = new_record[new_record_key]*1.0/num_day
return new_record
#function 1.4
#generate the dateList: [20200101, 20200102, ..., 20211231]
def generate_dateList():
yearList = ["2020","2021"]
monthList = ["0"+str(i+1) for i in range(9)] + ["10","11","12"]
dayList = ["0"+str(i+1) for i in range(9)] + [str(i) for i in range(10,32)]
day_2020_num = [31,29,31,30,31,30,31,31,30,31,30,31]
day_2021_num = [31,28,31,30,31,30,31,31,30,31,30,31]
date_2020, date_2021 = list(), list()
for i in range(12):
for j in range(day_2020_num[i]):
date_2020.append(yearList[0] + monthList[i] + dayList[j])
for j in range(day_2021_num[i]):
date_2021.append(yearList[1] + monthList[i] + dayList[j])
date_2020_2021 = date_2020 + date_2021
return date_2020_2021
#function 1.5
#smooth the mobility (infection) data using the neighborhood average
#under a given window size
#dateList: [20200101, 20200102, ..., 20211231]
def mob_inf_smooth(data, window_size, dateList):
data_copy = copy.copy(data)
data_key_list = list(data_copy.keys())
for data_key in data_key_list:
left = int(max(dateList.index(data_key)-(window_size-1)/2, 0))
right = int(min(dateList.index(data_key)+(window_size-1)/2, len(dateList)-1))
potential_neighbor = dateList[left:right+1]
neighbor_data_key = list(set(data_key_list).intersection(set(potential_neighbor)))
data_average = mob_inf_average_multiple(data_copy, neighbor_data_key)
data[data_key] = data_average
return data
#function 1.6
#set the mobility (infection) of one day as zero
def mob_inf_average_null(data, key1, key2):
new_record = dict()
record1, record2 = data[key1], data[key2]
for i in record1:
if i in record2:
new_record[i] = 0
return new_record
#function 1.7
#read the mobility data from "mobility_feature_20200201.json"...
#return: all_mobility:{"20200201":{('123','123'):12345,...},...}
#20200201 to 20210620: 506 days
def read_mobility_data(jcode23):
all_mobility = dict()
mobilityFilePath = "/data/HSEES/xue/xue_codes/disease_prediction_ml/gml_code/"+\
"present_model_version10/mobility_20210804"
mobilityNameList = os.listdir(mobilityFilePath)
for i in range(len(mobilityNameList)):
day_mobility = dict()
file_name = mobilityNameList[i]
if "20" in file_name:
day = (file_name.split("_")[2]).split(".")[0] #get the day
file_path = mobilityFilePath + '/' + file_name
f = open(file_path,)
df_file = json.load(f) #read the mobility file
f.close()
for key in df_file:
origin, dest = key.split("_")[0], key.split("_")[1]
if origin in jcode23 and dest in jcode23:
if origin == dest:
day_mobility[(origin, dest)] = 0.0 #ignore the inner-zone flow
else:
day_mobility[(origin, dest)] = df_file[key]
all_mobility[day] = day_mobility
#missing data
all_mobility["20201128"] = mob_inf_average(all_mobility,"20201127","20201129")
all_mobility["20210104"] = mob_inf_average(all_mobility, "20210103","20210105")
return all_mobility
##################2.Text#####################
#function 1.8
#get the average of two days' infection records
def text_average(data, key1, key2):
new_record = dict()
record1, record2 = data[key1], data[key2]
for i in record1:
if i in record2:
zone_record1, zone_record2 = record1[i], record2[i]
new_zone_record = dict()
for j in zone_record1:
if j in zone_record2:
new_zone_record[j] = (zone_record1[j] + zone_record2[j])/2.0
new_record[i] = new_zone_record
return new_record
#function 1.9
#get the average of multiple days' text records
def text_average_multiple(data, keyList):
new_record = dict()
num_day = len(keyList)
for i in range(num_day):
record = data[keyList[i]]
for zone_id in record: #zone_id
if zone_id not in new_record:
new_record[zone_id] = dict()
for j in record[zone_id]: #symptom
if j not in new_record[zone_id]:
new_record[zone_id][j] = record[zone_id][j]
else:
new_record[zone_id][j] += record[zone_id][j]
for zone_id in new_record:
for j in new_record[zone_id]:
new_record[zone_id][j] = new_record[zone_id][j]*1.0/num_day
return new_record
#function 1.10
#smooth the text data using the neighborhood average
#under a given window size
def text_smooth(data, window_size, dateList):
data_copy = copy.copy(data)
data_key_list = list(data_copy.keys())
for data_key in data_key_list:
left = int(max(dateList.index(data_key)-(window_size-1)/2, 0))
right = int(min(dateList.index(data_key)+(window_size-1)/2, len(dateList)-1))
potential_neighbor = dateList[left:right+1]
neighbor_data_key = list(set(data_key_list).intersection(set(potential_neighbor)))
data_average = text_average_multiple(data_copy, neighbor_data_key)
data[data_key] = data_average
return data
#function 1.11
#read the number of user points
def read_point_json():
with open('user_point/mobility_user_point.json') as point1:
user_point1 = json.load(point1)
with open('user_point/mobility_user_point_20210812.json') as point2:
user_point2 = json.load(point2)
user_point_all = dict()
for i in user_point1:
user_point_all[i] = user_point1[i]
for i in user_point2:
user_point_all[i] = user_point2[i]
user_point_all["20201128"] = user_point_all["20201127"] #data missing
user_point_all["20210104"] = user_point_all["20210103"] #data missing
return user_point_all
#function 1.12
#normalize the text search by the number of user points.
def normalize_text_user(all_text, user_point_all):
for day in all_text:
if day in user_point_all:
num_user = user_point_all[day]["num_user"]
all_text_day_new = dict()
all_text_day = all_text[day]
for zone in all_text_day:
if zone not in all_text_day_new:
all_text_day_new[zone] = dict()
for sym in all_text_day[zone]:
all_text_day_new[zone][sym] = all_text_day[zone][sym]*1.0/num_user
all_text[day] = all_text_day_new
return all_text
#function 1.13
#read the text data
#20200201 to 20210620: 506 days
#all_text = {"20200211":{"123":{"code":3,"fever":2,...},...},...}
def read_text_data(jcode23):
all_text = dict()
textFilePath = "/data/HSEES/xue/xue_codes/disease_prediction_ml/gml_code/"+\
"present_model_version10/text_20210804"
textNameList = os.listdir(textFilePath)
for i in range(len(textNameList)):
day_text = dict()
file_name = textNameList[i]
if "20" in file_name:
day = (file_name.split("_")[2]).split(".")[0]
file_path = textFilePath + "/" + file_name
f = open(file_path,)
df_file = json.load(f) #read the mobility file
f.close()
new_dict = dict()
for key in df_file:
if key in jcode23:
new_dict[key] = {key1:df_file[key][key1]*1.0*web_search_normalize_ratio for key1 in df_file[key]}
#new_dict[key] = df_file[key]*WEB_SEARCH_RATIO
all_text[day] = new_dict
all_text["20201030"] = text_average(all_text, "20201029", "20201031") #data missing
return all_text
#function 1.14
#perform the min-max normalization for the text data.
def min_max_text_data(all_text,jcode23):
#calculate the min_max
#region_key: sym: [min,max]
text_list = list(['痛み', '頭痛', '咳', '下痢', 'ストレス', '不安', \
'腹痛', 'めまい', '吐き気', '嘔吐', '筋肉痛', '動悸', \
'副鼻腔炎', '発疹', 'くしゃみ', '倦怠感', '寒気', '脱水', \
'中咽頭', '関節痛', '不眠症', '睡眠障害', '鼻漏', '片頭痛', \
'多汗症', 'ほてり', '胸痛', '発汗', '無気力', '呼吸困難', \
'喘鳴', '目の痛み', '体の痛み', '無嗅覚症', '耳の痛み', \
'錯乱', '見当識障害', '胸の圧迫感', '鼻の乾燥', '耳感染症', \
'味覚消失', '上気道感染症', '眼感染症', '食欲減少'])
region_sym_min_max = dict()
for key in jcode23: #initialize
region_sym_min_max[key] = dict()
for sym in text_list:
region_sym_min_max[key][sym] = [1000000,0] #min, max
for day in all_text: #update
for key in jcode23:
for sym in text_list:
if sym in all_text[day][key]:
count = all_text[day][key][sym]
if count < region_sym_min_max[key][sym][0]:
region_sym_min_max[key][sym][0] = count
if count > region_sym_min_max[key][sym][1]:
region_sym_min_max[key][sym][1] = count
#print ("region_sym_min_max",region_sym_min_max)
for key in jcode23: #normalize
for sym in text_list:
min_count,max_count=region_sym_min_max[key][sym][0],region_sym_min_max[key][sym][1]
for day in all_text:
if sym in all_text[day][key]:
if max_count-min_count == 0:
all_text[day][key][sym] = 1
else:
all_text[day][key][sym] = (all_text[day][key][sym]-min_count)*1.0/(max_count-min_count)
#print("all_text[day][key][sym]",all_text[day][key][sym])
return all_text
##################3.Infection#####################
#function 1.15
#read the infection data
#20200331 to 20210620: (1+30+31+30+31+31+30+31+30+31)+(31+28+31+30+31+20) = 276 + 171 = 447.
#all_infection = {"20200201":{"123":1,"123":2}}
def read_infection_data(jcode23):
all_infection = dict()
infection_path = "/data/HSEES/xue/xue_codes/disease_prediction_ml/gml_code/"+\
"present_model_version10/patient_20210725.json"
f = open(infection_path,)
df_file = json.load(f) #read the mobility file
f.close()
for zone_id in df_file:
for one_day in df_file[zone_id]:
daySplit = one_day.split("/")
year, month, day = daySplit[0], daySplit[1], daySplit[2]
if len(month) == 1:
month = "0" + month
if len(day) == 1:
day = "0" + day
new_date = year + month + day
if str(zone_id[0:5]) in jcode23:
if new_date not in all_infection:
all_infection[new_date] = {zone_id[0:5]:df_file[zone_id][one_day]*1.0/infection_normalize_ratio}
else:
all_infection[new_date][zone_id[0:5]] = df_file[zone_id][one_day]*1.0/infection_normalize_ratio
#missing
date_list = [str(20200316+i) for i in range(15)]
for date in date_list:
all_infection[date] = mob_inf_average(all_infection,'20200401','20200401')
all_infection['20200514'] = mob_inf_average(all_infection,'20200513','20200515')
all_infection['20200519'] = mob_inf_average(all_infection,'20200518','20200520')
all_infection['20200523'] = mob_inf_average(all_infection,'20200522','20200524')
all_infection['20200530'] = mob_inf_average(all_infection,'20200529','20200601')
all_infection['20200531'] = mob_inf_average(all_infection,'20200529','20200601')
all_infection['20201231'] = mob_inf_average(all_infection,'20201230','20210101')
all_infection['20210611'] = mob_inf_average(all_infection,'20210610','20210612')
#outlier
all_infection['20200331'] = mob_inf_average(all_infection,'20200401','20200401')
all_infection['20200910'] = mob_inf_average(all_infection,'20200909','20200912')
all_infection['20200911'] = mob_inf_average(all_infection,'20200909','20200912')
all_infection['20200511'] = mob_inf_average(all_infection,'20200510','20200512')
all_infection['20201208'] = mob_inf_average(all_infection,'20201207','20201209')
all_infection['20210208'] = mob_inf_average(all_infection,'20210207','20210209')
all_infection['20210214'] = mob_inf_average(all_infection,'20210213','20210215')
#calculate the subtraction
all_infection_subtraction = dict()
all_infection_subtraction['20200331'] = all_infection['20200331']
all_keys = list(all_infection.keys())
all_keys.sort()
for i in range(len(all_keys)-1):
record = dict()
for j in all_infection[all_keys[i+1]]:
record[j] = all_infection[all_keys[i+1]][j] - all_infection[all_keys[i]][j]
all_infection_subtraction[all_keys[i+1]] = record
return all_infection_subtraction, all_infection
##################4.Preprocess#####################
#function 1.16
#ensemble the mobility, text, and infection.
#all_mobility = {"20200201":{('123','123'):12345,...},...}
#all_text = {"20200201":{"123":{"cold":3,"fever":2,...},...},...}
#all_infection = {"20200316":{"123":1,"123":2}}
#all_x_y = {"0":[[mobility_1,text_1, ..., mobility_x_day,text_x_day], [infection_1,...,infection_y_day],\
#[infection_1,...,infection_x_day]],0}
#x_days, y_days: use x_days to predict y_days
def ensemble(all_mobility, all_text, all_infection, x_days, y_days, all_day_list):
all_x_y = dict()
for j in range(len(all_day_list) - x_days - y_days + 1):
x_sample, y_sample, x_sample_infection = list(), list(), list()
#add the data from all_day_list[0+j] to all_day_list[x_days-1+j]
for k in range(x_days):
day = all_day_list[k + j]
x_sample.append(all_mobility[day])
x_sample.append(all_text[day])
x_sample_infection.append(all_infection[day]) #concatenate with the infection data
#add the data from all_day_list[x_days+j] to all_day_list[x_days+y_day-1+j]
for k in range(y_days):
day = all_day_list[x_days + k + j]
y_sample.append(all_infection[day])
all_x_y[str(j)] = [x_sample, y_sample, x_sample_infection,j]
return all_x_y
#function 1.17
#split the data by train/validate/test = train_ratio/validation_ratio/(1-train_ratio-validation_ratio)
def split_data(all_x_y, train_ratio, validation_ratio):
all_x_y_key = list(all_x_y.keys())
n = len(all_x_y_key)
n_train, n_validate = round(n*train_ratio), round(n*validation_ratio)
n_test = n-n_train-n_validate
train_key = [all_x_y[str(i)] for i in range(n_train)]
validate_key = [all_x_y[str(i+n_train)] for i in range(n_validate)]
test_key = [all_x_y[str(i+n_train+n_validate)] for i in range(n_test)]
return train_key, validate_key, test_key
##function 1.18
#the second data split method
#split the data by train/validate/test = train_ratio/validation_ratio/(1-train_ratio-validation_ratio)
def split_data_2(all_x_y, train_ratio, validation_ratio):
all_x_y_key = list(all_x_y.keys())
n = len(all_x_y_key)
n_train, n_validate = round(n*train_ratio), round(n*validation_ratio)
n_test = n-n_train-n_validate
train_list, validate_list = list(), list()
train_validate_key = [all_x_y[str(i)] for i in range(n_train+n_validate)]
train_key, validate_key = list(), list()
for i in range(len(train_validate_key)):
if i % 9 == 8:
validate_key.append(all_x_y[str(i)])
validate_list.append(i)
else:
train_key.append(all_x_y[str(i)])
train_list.append(i)
test_key = [all_x_y[str(i+n_train+n_validate)] for i in range(n_test)]
return train_key, validate_key, test_key, train_list, validate_list
##function 1.19
#the third data split method
#split the data by train/validate/test = train_ratio/validation_ratio/(1-train_ratio-validation_ratio)
def split_data_3(all_x_y, train_ratio, validation_ratio):
all_x_y_key = list(all_x_y.keys())
n = len(all_x_y_key)
n_train, n_validate = round(n*train_ratio), round(n*validation_ratio)
n_test = n - n_train - n_validate
train_list, validate_list = list(), list()
train_validate_key = [all_x_y[str(i)] for i in range(n_train + n_validate)]
train_key, validate_key = list(), list()
for i in range(len(train_validate_key)):
if (n_train + n_validate-i) % 2 == 0 and (n_train + n_validate-i) <= 2*n_validate:
validate_key.append(all_x_y[str(i)])
validate_list.append(i)
else:
train_key.append(all_x_y[str(i)])
train_list.append(i)
test_key = [all_x_y[str(i+n_train+n_validate)] for i in range(n_test)]
return train_key, validate_key, test_key, train_list, validate_list
##function 1.20
#find the mobility data starting from the day, which is x_days before the start_date
#start_date = "20200331", x_days = 7
def sort_date(all_mobility, start_date, x_days):
mobility_date_list = list(all_mobility.keys())
mobility_date_list.sort()
idx = mobility_date_list.index(start_date)
mobility_date_cut = mobility_date_list[idx-x_days:]
return mobility_date_cut
#function 1.21
#find the mobility data starting from the day, which is x_days before the start_date,
#ending at the day, which is y_days after the end_date
#start_date = "20200331", x_days = 7
def sort_date_2(all_mobility, start_date, x_days, end_date, y_days):
mobility_date_list = list(all_mobility.keys())
mobility_date_list.sort()
idx = mobility_date_list.index(start_date)
idx2 = mobility_date_list.index(end_date)
mobility_date_cut = mobility_date_list[idx-x_days:idx2+y_days]
return mobility_date_cut
#function 1.22
#get the mappings from zone id to id, text id to id.
#get zone_text_to_idx
def get_zone_text_to_idx(all_infection):
zone_list = list(set(all_infection["20200401"].keys()))
text_list = list(['痛み', '頭痛', '咳', '下痢', 'ストレス', '不安', \
'腹痛', 'めまい'])
zone_list.sort()
zone_dict = {str(zone_list[i]):i for i in range(len(zone_list))}
text_dict = {str(text_list[i]):i for i in range(len(text_list))}
return zone_dict, text_dict
#function 1.23
#change the data format to matrix
#zoneid_to_idx = {"13101":0, "13102":1, ..., "13102":22}
#sym_to_idx = {"cough":0}
#mobility: {('13101', '13101'): 709973, ...}
#text: {'13101': {'痛み': 51,...},...} text
#infection: {'13101': 50, '13102': 137, '13103': 401,...}
#data_type = {"mobility", "text", "infection"}
def to_matrix(zoneid_to_idx, sym_to_idx, input_data, data_type):
n_zone, n_text = len(zoneid_to_idx), len(sym_to_idx)
if data_type == "mobility":
result = np.zeros((n_zone, n_zone))
for key in input_data:
from_id, to_id = key[0], key[1]
from_idx, to_idx = zoneid_to_idx[from_id], zoneid_to_idx[to_id]
result[from_idx][to_idx] += input_data[key]
if data_type == "text":
result = np.zeros((n_zone, n_text))
for key1 in input_data:
for key2 in input_data[key1]:
if key1 in list(zoneid_to_idx.keys()) and key2 in list(sym_to_idx.keys()):
zone_idx, text_idx = zoneid_to_idx[key1], sym_to_idx[key2]
result[zone_idx][text_idx] += input_data[key1][key2]
if data_type == "infection":
result = np.zeros(n_zone)
for key in input_data:
zone_idx = zoneid_to_idx[key]
result[zone_idx] += input_data[key]
return result
#function 1.24
#change the data to the matrix format
def change_to_matrix(data, zoneid_to_idx, sym_to_idx):
data_result = list()
for i in range(len(data)):
combine1, combine2 = list(), list()
combine3 = list() #NEW
mobility_text = data[i][0]
x_infection_all = data[i][2] #the x_days infection data
day_order = data[i][3] #NEW the order of the day
for j in range(round(len(mobility_text)*1.0/2)):
mobility, text = mobility_text[2*j], mobility_text[2*j+1]
x_infection = x_infection_all[j] #NEW
new_mobility = to_matrix(zoneid_to_idx, sym_to_idx, mobility, "mobility")
new_text = to_matrix(zoneid_to_idx, sym_to_idx, text, "text")
combine1.append(new_mobility)
combine1.append(new_text)
new_x_infection = to_matrix(zoneid_to_idx, sym_to_idx, x_infection, "infection") #NEW
combine3.append(new_x_infection) #NEW
for j in range(len(data[i][1])):
infection = data[i][1][j]
new_infection = to_matrix(zoneid_to_idx, sym_to_idx, infection, "infection")
combine2.append(new_infection)
data_result.append([combine1,combine2,combine3,day_order]) #mobility/text; infection_y; infection_x; day_order
return data_result
##################5.learn#####################
#function 1.25
def visual_loss(e_losses, vali_loss, test_loss):
plt.figure(figsize=(4,3), dpi=300)
x = range(len(e_losses))
y1,y2,y3 = copy.copy(e_losses), copy.copy(vali_loss), copy.copy(test_loss)
plt.plot(x,y1,linewidth=1, label="train")
plt.plot(x,y2,linewidth=1, label="validate")
plt.plot(x,y3,linewidth=1, label="test")
plt.legend()
plt.title('Loss decline on entire training/validation/testing data')
plt.xlabel('Epoch')
plt.ylabel('Loss')
#plt.savefig('final_f6.png',bbox_inches = 'tight')
plt.show()
#function 1.26
def visual_loss_train(e_losses):
plt.figure(figsize=(4,3), dpi=300)
x = range(len(e_losses))
y1 = copy.copy(e_losses)
plt.plot(x,y1,linewidth=1, label="train")
plt.legend()
plt.title('Loss decline on entire training data')
plt.xlabel('Epoch')
plt.ylabel('Loss')
#plt.savefig('final_f6.png',bbox_inches = 'tight')
plt.show()
#function 2.1
#normalize each column of the input mobility matrix as one
def normalize_column_one(input_matrix):
column_sum = np.sum(input_matrix, axis=0)
row_num, column_num = len(input_matrix), len(input_matrix[0])
for i in range(row_num):
for j in range(column_num):
input_matrix[i][j] = input_matrix[i][j]*1.0/column_sum[j]
return input_matrix
#function 2.2
#evalute the trained_model on validation or testing data.
def validate_test_process(trained_model, vali_test_data):
criterion = nn.MSELoss()
vali_test_y = [vali_test_data[i][1] for i in range(len(vali_test_data))]
y_real = torch.tensor(vali_test_y)
vali_test_x = [vali_test_data[i] for i in range(len(vali_test_data))]
vali_test_x = convertAdj(vali_test_x)
y_hat = trained_model.run_specGCN_lstm(vali_test_x)
loss = criterion(y_hat.float(), y_real.float()) ###Calculate the loss
return loss, y_hat, y_real
#function 2.3
#convert the mobility matrix in x_batch in a following way
#normalize the flow between zones so that the in-flow of each zone is 1.
def convertAdj(x_batch):
#x_batch:(n_batch, 0/1, 2*i+1)
x_batch_new = copy.copy(x_batch)
n_batch = len(x_batch)
days = round(len(x_batch[0][0])/2)
for i in range(n_batch):
for j in range(days):
mobility_matrix = x_batch[i][0][2*j]
x_batch_new[i][0][2*j] = normalize_column_one(mobility_matrix) #20210818
return x_batch_new
#function 2.4
#a training epoch
def train_epoch_option(model, opt, criterion, trainX_c, trainY_c, batch_size):
model.train()
losses = []
batch_num = 0
for beg_i in range(0, len(trainX_c), batch_size):
batch_num += 1
if batch_num % 16 ==0:
print ("batch_num: ", batch_num, "total batch number: ", int(len(trainX_c)/batch_size))
x_batch = trainX_c[beg_i:beg_i+batch_size]
y_batch = torch.tensor(trainY_c[beg_i:beg_i+batch_size])
opt.zero_grad()
x_batch = convertAdj(x_batch) #conduct the column normalization
y_hat = model.run_specGCN_lstm(x_batch) ###Attention
loss = criterion(y_hat.float(), y_batch.float()) #MSE loss
#opt.zero_grad()
loss.backward()
opt.step()
losses.append(loss.data.numpy())
return sum(losses)/float(len(losses)), model
#function 2.5
#multiple training epoch
def train_process(train_data, lr, num_epochs, net, criterion, bs, vali_data, test_data):
opt = optim.Adam(net.parameters(), lr, betas = (0.9,0.999), weight_decay=0)
train_y = [train_data[i][1] for i in range(len(train_data))]
e_losses = list()
e_losses_vali = list()
e_losses_test = list()
time00 = time.time()
for e in range(num_epochs):
time1 = time.time()
print ("current epoch: ",e, "total epoch: ", num_epochs)
number_list = list(range(len(train_data)))
random.shuffle(number_list, random = r)
trainX_sample = [train_data[number_list[j]] for j in range(len(number_list))]
trainY_sample = [train_y[number_list[j]] for j in range(len(number_list))]
loss, net = train_epoch_option(net, opt, criterion, trainX_sample, trainY_sample, bs)
print ("train loss", loss*infection_normalize_ratio*infection_normalize_ratio)
e_losses.append(loss*infection_normalize_ratio*infection_normalize_ratio)
loss_vali, y_hat_vali, y_real_vali = validate_test_process(net, vali_data)
loss_test, y_hat_test, y_real_test = validate_test_process(net, test_data)
e_losses_vali.append(float(loss_vali)*infection_normalize_ratio*infection_normalize_ratio)
e_losses_test.append(float(loss_test)*infection_normalize_ratio*infection_normalize_ratio)
print ("validate loss", float(loss_vali)*infection_normalize_ratio*infection_normalize_ratio)
print ("test loss", float(loss_test)*infection_normalize_ratio*infection_normalize_ratio)
if e>=2 and (e+1)%10 ==0:
visual_loss(e_losses, e_losses_vali, e_losses_test)
visual_loss_train(e_losses)
time2 = time.time()
print ("running time for this epoch:", time2 - time1)
time01 = time.time()
print ("---------------------------------------------------------------")
print ("---------------------------------------------------------------")
#print ("total running time until now:", time01 - time00)
#print ("------------------------------------------------")
#print("specGCN_weight", net.specGCN.layer1.W)
#print("specGCN_weight_grad", net.specGCN.layer1.W.grad)
#print ("------------------------------------------------")
#print("memory decay matrix", net.v)
#print("memory decay matrix grad", net.v.grad)
#print ("------------------------------------------------")
#print ("lstm weight", net.lstm.all_weights[0][0])
#print ("lstm weight grad", net.lstm.all_weights[0][0].grad)
#print ("------------------------------------------------")
#print ("fc1.weight", net.fc1.weight)
#print ("fc1 weight grd", net.fc1.weight.grad)
#print ("---------------------------------------------------------------")
#print ("---------------------------------------------------------------")
return e_losses, net
#function 3.1
def read_data():
jcode23 = list(read_tokyo_23()["JCODE"]) #1.1 get the tokyo 23 zone shapefile
all_mobility = read_mobility_data(jcode23) #1.2 read the mobility data
all_text = read_text_data(jcode23) #1.3 read the text data
all_infection, all_infection_cum = read_infection_data(jcode23) #1.4 read the infection data
#smooth the data using 7-days average
window_size = WINDOW_SIZE #20210818
dateList = generate_dateList() #20210818
all_mobility = mob_inf_smooth(all_mobility, window_size, dateList) #20210818
all_infection = mob_inf_smooth(all_infection, window_size, dateList) #20210818
#smooth, user, min-max.
point_json = read_point_json() #20210821
all_text = normalize_text_user(all_text, point_json) #20210821
all_text = text_smooth(all_text, window_size, dateList) #20210818
all_text = min_max_text_data(all_text,jcode23) #20210820
x_days, y_days = X_day, Y_day
mobility_date_cut = sort_date_2(all_mobility, START_DATE, x_days, END_DATE, y_days)
all_x_y = ensemble(all_mobility, all_text, all_infection, x_days, y_days, mobility_date_cut)
train_original, validate_original, test_original, train_list, validation_list =\
split_data_3(all_x_y,train_ratio,validate_ratio)
zone_dict, text_dict = get_zone_text_to_idx(all_infection) #get zone_idx, text_idx
train_x_y = change_to_matrix(train_original, zone_dict, text_dict) #get train
print ("train_x_y_shape",len(train_x_y),"train_x_y_shape[0]",len(train_x_y[0]))
validate_x_y = change_to_matrix(validate_original, zone_dict, text_dict) #get validate
test_x_y = change_to_matrix(test_original, zone_dict, text_dict) #get test
print (len(train_x_y)) #300
print (len(train_x_y[0][0])) #14
print (np.shape(train_x_y[0][0][0])) #(23,23)
print (np.shape(train_x_y[0][0][1])) #(23,43)
#print ("---------------------------------finish data reading and preprocessing------------------------------------")
return train_x_y, validate_x_y, test_x_y, all_mobility, all_infection, train_original, validate_original, test_original, train_list, validation_list
#function 3.2
#train the model
def model_train(train_x_y, vali_data, test_data):
#3.2.1 define the model
input_dim_1, hidden_dim_1, out_dim_1, hidden_dim_2 = len(train_x_y[0][0][1][1]),\
HIDDEN_DIM_1, OUT_DIM_1, HIDDEN_DIM_2
dropout_1, alpha_1, N = DROPOUT, ALPHA, len(train_x_y[0][0][1])
G_L_Model = SpecGCN_LSTM(X_day, Y_day, input_dim_1, hidden_dim_1, out_dim_1, hidden_dim_2, dropout_1,N) ###Attention
#3.2.2 train the model
num_epochs, batch_size, learning_rate = NUM_EPOCHS, BATCH_SIZE, LEARNING_RATE #model train
criterion = nn.MSELoss()
e_losses, trained_model = train_process(train_x_y, learning_rate, num_epochs, G_L_Model, criterion, batch_size,\
vali_data, test_data)
return e_losses, trained_model
#function 3.3
#evaluate the error on validation (or testing) data.
def validate_test_process(trained_model, vali_test_data):
criterion = nn.MSELoss()
vali_test_y = [vali_test_data[i][1] for i in range(len(vali_test_data))]
y_real = torch.tensor(vali_test_y)
vali_test_x = [vali_test_data[i] for i in range(len(vali_test_data))]
vali_test_x = convertAdj(vali_test_x)
y_hat = trained_model.run_specGCN_lstm(vali_test_x) ###Attention
loss = criterion(y_hat.float(), y_real.float())
return loss, y_hat, y_real
#4.1
#read the data
train_x_y, validate_x_y, test_x_y, all_mobility, all_infection, \
train_original, validate_original, test_original, train_list, validation_list =\
read_data()
#train_x_y, validate_x_y, test_x_y = normalize(train_x_y, validate_x_y, test_x_y)
#train_x_y = train_x_y[0:30]
print (len(train_x_y))
print ("---------------------------------finish data preparation------------------------------------")
train_x_y_shape 210 train_x_y_shape[0] 4 210 42 (23, 23) (23, 8) 210 ---------------------------------finish data preparation------------------------------------
#4.2
#train the model
e_losses, trained_model = model_train(train_x_y, validate_x_y, test_x_y)
print ("---------------------------finish model training-------------------------")
current epoch: 0 total epoch: 100 batch_num: 16 total batch number: 26 train loss 417.9716208535764 validate loss 306.71097338199615 test loss 438.00219893455505 running time for this epoch: 15.458380699157715 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 1 total epoch: 100 batch_num: 16 total batch number: 26 train loss 378.7486238784536 validate loss 303.4638985991478 test loss 410.46880185604095 running time for this epoch: 15.448132514953613 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 2 total epoch: 100 batch_num: 16 total batch number: 26 train loss 365.470887541219 validate loss 299.2762066423893 test loss 379.06549870967865 running time for this epoch: 15.368163108825684 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 3 total epoch: 100 batch_num: 16 total batch number: 26 train loss 345.97584701798576 validate loss 293.42496767640114 test loss 341.3648158311844 running time for this epoch: 15.853237390518188 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 4 total epoch: 100 batch_num: 16 total batch number: 26 train loss 327.22532455981883 validate loss 287.3804233968258 test loss 305.132232606411 running time for this epoch: 16.097155809402466 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 5 total epoch: 100 batch_num: 16 total batch number: 26 train loss 308.7711418737416 validate loss 280.2027203142643 test loss 270.9529176354408 running time for this epoch: 16.54499125480652 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 6 total epoch: 100 batch_num: 16 total batch number: 26 train loss 294.67212536406737 validate loss 270.7592211663723 test loss 238.01786825060844 running time for this epoch: 16.40885591506958 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 7 total epoch: 100 batch_num: 16 total batch number: 26 train loss 280.1055480139675 validate loss 258.1782825291157 test loss 209.4724215567112 running time for this epoch: 16.51949977874756 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 8 total epoch: 100 batch_num: 16 total batch number: 26 train loss 269.86572011891343 validate loss 243.41557174921036 test loss 185.42908132076263 running time for this epoch: 16.747636556625366 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 9 total epoch: 100 batch_num: 16 total batch number: 26 train loss 252.74953304755465 validate loss 225.72271525859833 test loss 166.0982146859169
running time for this epoch: 17.187342405319214 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 10 total epoch: 100 batch_num: 16 total batch number: 26 train loss 258.6636078302507 validate loss 206.96524530649185 test loss 148.6186496913433 running time for this epoch: 16.750561475753784 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 11 total epoch: 100 batch_num: 16 total batch number: 26 train loss 230.6269417965302 validate loss 186.10235303640366 test loss 132.18523003160954 running time for this epoch: 16.753048181533813 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 12 total epoch: 100 batch_num: 16 total batch number: 26 train loss 221.03159257245287 validate loss 164.36995938420296 test loss 118.42284351587296 running time for this epoch: 17.152021646499634 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 13 total epoch: 100 batch_num: 16 total batch number: 26 train loss 227.94409396333828 validate loss 148.82631599903107 test loss 109.81934145092964 running time for this epoch: 17.09384560585022 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 14 total epoch: 100 batch_num: 16 total batch number: 26 train loss 202.40490324795246 validate loss 132.38388113677502 test loss 100.48043914139271 running time for this epoch: 17.184275150299072 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 15 total epoch: 100 batch_num: 16 total batch number: 26 train loss 194.3508533898879 validate loss 118.43916960060596 test loss 94.00967508554459 running time for this epoch: 16.86436438560486 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 16 total epoch: 100 batch_num: 16 total batch number: 26 train loss 196.27666480287357 validate loss 106.75572790205479 test loss 86.39315143227577 running time for this epoch: 16.87771439552307 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 17 total epoch: 100 batch_num: 16 total batch number: 26 train loss 181.3606904060752 validate loss 96.85984812676907 test loss 82.08923041820526 running time for this epoch: 17.120392322540283 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 18 total epoch: 100 batch_num: 16 total batch number: 26 train loss 186.35736278223771 validate loss 89.77429941296577 test loss 78.13549600541592 running time for this epoch: 17.165250301361084 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 19 total epoch: 100 batch_num: 16 total batch number: 26 train loss 169.8294564805649 validate loss 82.02112279832363 test loss 74.40332788974047
running time for this epoch: 17.454917907714844 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 20 total epoch: 100 batch_num: 16 total batch number: 26 train loss 164.02025448365345 validate loss 76.02673955261707 test loss 70.26772014796734 running time for this epoch: 17.030155181884766 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 21 total epoch: 100 batch_num: 16 total batch number: 26 train loss 160.01665240360632 validate loss 71.97089027613401 test loss 68.13614163547754 running time for this epoch: 17.13310480117798 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 22 total epoch: 100 batch_num: 16 total batch number: 26 train loss 156.7840893511419 validate loss 68.65616422146559 test loss 66.15381687879562 running time for this epoch: 17.087868690490723 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 23 total epoch: 100 batch_num: 16 total batch number: 26 train loss 151.99173509177785 validate loss 65.07563404738903 test loss 64.06016647815704 running time for this epoch: 17.095673322677612 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 24 total epoch: 100 batch_num: 16 total batch number: 26 train loss 148.19561581437787 validate loss 62.5916151329875 test loss 63.38438019156456 running time for this epoch: 17.205768823623657 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 25 total epoch: 100 batch_num: 16 total batch number: 26 train loss 172.90235152123154 validate loss 60.49708463251591 test loss 62.67072167247534 running time for this epoch: 17.201245069503784 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 26 total epoch: 100 batch_num: 16 total batch number: 26 train loss 141.97096547663764 validate loss 59.22661628574133 test loss 58.66761785000563 running time for this epoch: 17.212970733642578 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 27 total epoch: 100 batch_num: 16 total batch number: 26 train loss 154.9931717346664 validate loss 58.71174391359091 test loss 58.14725998789072 running time for this epoch: 17.429460763931274 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 28 total epoch: 100 batch_num: 16 total batch number: 26 train loss 140.63723040399728 validate loss 58.960942551493645 test loss 56.66694603860378 running time for this epoch: 17.457489728927612 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 29 total epoch: 100 batch_num: 16 total batch number: 26 train loss 140.31646842205967 validate loss 58.10135509818792 test loss 55.77628035098314
running time for this epoch: 17.900179147720337 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 30 total epoch: 100 batch_num: 16 total batch number: 26 train loss 134.4046522690742 validate loss 58.579519391059875 test loss 55.42920436710119 running time for this epoch: 17.475611209869385 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 31 total epoch: 100 batch_num: 16 total batch number: 26 train loss 141.91800366259284 validate loss 58.266716077923775 test loss 55.3625775501132 running time for this epoch: 17.436043977737427 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 32 total epoch: 100 batch_num: 16 total batch number: 26 train loss 131.9578075264063 validate loss 58.46631247550249 test loss 55.37671037018299 running time for this epoch: 17.348163604736328 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 33 total epoch: 100 batch_num: 16 total batch number: 26 train loss 127.45785151815248 validate loss 59.72000304609537 test loss 53.90392616391182 running time for this epoch: 17.516541957855225 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 34 total epoch: 100 batch_num: 16 total batch number: 26 train loss 127.21705336675599 validate loss 59.20489318668842 test loss 54.594590328633785 running time for this epoch: 17.484910249710083 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 35 total epoch: 100 batch_num: 16 total batch number: 26 train loss 127.56032903713208 validate loss 60.61756983399391 test loss 54.686879739165306 running time for this epoch: 17.405664443969727 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 36 total epoch: 100 batch_num: 16 total batch number: 26 train loss 126.62659899366122 validate loss 60.16986910253763 test loss 55.88961765170097 running time for this epoch: 17.50118398666382 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 37 total epoch: 100 batch_num: 16 total batch number: 26 train loss 134.73004525251412 validate loss 61.42263300716877 test loss 54.847002029418945 running time for this epoch: 17.454529285430908 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 38 total epoch: 100 batch_num: 16 total batch number: 26 train loss 121.65009932316563 validate loss 64.12073969841003 test loss 53.873201832175255 running time for this epoch: 17.470115900039673 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 39 total epoch: 100 batch_num: 16 total batch number: 26 train loss 132.76336335197644 validate loss 64.52123168855906 test loss 54.27561700344086
running time for this epoch: 18.321196794509888 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 40 total epoch: 100 batch_num: 16 total batch number: 26 train loss 120.04540267365951 validate loss 65.53907878696918 test loss 53.56431007385254 running time for this epoch: 17.81260061264038 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 41 total epoch: 100 batch_num: 16 total batch number: 26 train loss 131.59193710803433 validate loss 65.39614871144295 test loss 55.143022909760475 running time for this epoch: 17.91212487220764 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 42 total epoch: 100 batch_num: 16 total batch number: 26 train loss 117.77610850037523 validate loss 68.70707497000694 test loss 53.779641166329384 running time for this epoch: 17.915064096450806 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 43 total epoch: 100 batch_num: 16 total batch number: 26 train loss 117.26960324233882 validate loss 68.99862084537745 test loss 53.93038038164377 running time for this epoch: 17.90891170501709 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 44 total epoch: 100 batch_num: 16 total batch number: 26 train loss 116.27486975932564 validate loss 69.28022019565105 test loss 55.1025802269578 running time for this epoch: 18.118045568466187 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 45 total epoch: 100 batch_num: 16 total batch number: 26 train loss 116.01506577183804 validate loss 70.8028906956315 test loss 54.50721364468336 running time for this epoch: 17.96769380569458 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 46 total epoch: 100 batch_num: 16 total batch number: 26 train loss 116.18900061067608 validate loss 71.97763305157423 test loss 54.67626266181469 running time for this epoch: 17.92027735710144 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 47 total epoch: 100 batch_num: 16 total batch number: 26 train loss 114.6098642816974 validate loss 72.57412187755108 test loss 54.832128807902336 running time for this epoch: 17.933958292007446 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 48 total epoch: 100 batch_num: 16 total batch number: 26 train loss 117.00230892057772 validate loss 72.5493161007762 test loss 55.505954660475254 running time for this epoch: 17.99398422241211 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 49 total epoch: 100 batch_num: 16 total batch number: 26 train loss 114.12933736142736 validate loss 73.0204489082098 test loss 56.33239634335041
running time for this epoch: 18.01486587524414 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 50 total epoch: 100 batch_num: 16 total batch number: 26 train loss 125.23211271467585 validate loss 75.24166721850634 test loss 55.31095899641514 running time for this epoch: 17.690500497817993 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 51 total epoch: 100 batch_num: 16 total batch number: 26 train loss 113.34797643401005 validate loss 78.25678214430809 test loss 53.62390074878931 running time for this epoch: 17.603981494903564 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 52 total epoch: 100 batch_num: 16 total batch number: 26 train loss 113.90373585055823 validate loss 79.18015122413635 test loss 54.383231326937675 running time for this epoch: 17.680060386657715 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 53 total epoch: 100 batch_num: 16 total batch number: 26 train loss 115.54538358554795 validate loss 79.03178222477436 test loss 54.13837730884552 running time for this epoch: 17.62855315208435 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 54 total epoch: 100 batch_num: 16 total batch number: 26 train loss 112.27687855285626 validate loss 77.71139033138752 test loss 55.635860189795494 running time for this epoch: 17.606141805648804 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 55 total epoch: 100 batch_num: 16 total batch number: 26 train loss 111.75754850868273 validate loss 78.8652803748846 test loss 55.315359495580196 running time for this epoch: 17.70355534553528 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 56 total epoch: 100 batch_num: 16 total batch number: 26 train loss 112.43012587160423 validate loss 80.22592402994633 test loss 54.92083728313446 running time for this epoch: 17.73955512046814 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 57 total epoch: 100 batch_num: 16 total batch number: 26 train loss 111.14760560707914 validate loss 78.50335910916328 test loss 56.013171561062336 running time for this epoch: 17.769867181777954 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 58 total epoch: 100 batch_num: 16 total batch number: 26 train loss 111.02399685316615 validate loss 81.69082924723625 test loss 54.51230797916651 running time for this epoch: 17.554762363433838 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 59 total epoch: 100 batch_num: 16 total batch number: 26 train loss 110.56129178204745 validate loss 81.24347776174545 test loss 55.57123105973005
running time for this epoch: 17.88039207458496 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 60 total epoch: 100 batch_num: 16 total batch number: 26 train loss 110.47835660132546 validate loss 82.85642601549625 test loss 54.47376053780317 running time for this epoch: 17.5984525680542 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 61 total epoch: 100 batch_num: 16 total batch number: 26 train loss 112.72172085699384 validate loss 84.4094529747963 test loss 54.338835179805756 running time for this epoch: 17.7717068195343 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 62 total epoch: 100 batch_num: 16 total batch number: 26 train loss 110.56104890519269 validate loss 83.61273445188999 test loss 54.675922729074955 running time for this epoch: 17.579650402069092 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 63 total epoch: 100 batch_num: 16 total batch number: 26 train loss 114.8348870790667 validate loss 85.03847755491734 test loss 53.857164457440376 running time for this epoch: 17.614727020263672 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 64 total epoch: 100 batch_num: 16 total batch number: 26 train loss 109.45954847287524 validate loss 85.21848358213902 test loss 54.982760921120644 running time for this epoch: 17.67832589149475 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 65 total epoch: 100 batch_num: 16 total batch number: 26 train loss 109.34254706457808 validate loss 86.34068071842194 test loss 54.81707863509655 running time for this epoch: 17.639902591705322 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 66 total epoch: 100 batch_num: 16 total batch number: 26 train loss 110.02776152626782 validate loss 86.86982095241547 test loss 53.985691629350185 running time for this epoch: 17.64319896697998 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 67 total epoch: 100 batch_num: 16 total batch number: 26 train loss 111.00810133265678 validate loss 87.02140301465988 test loss 54.23382855951786 running time for this epoch: 17.6591374874115 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 68 total epoch: 100 batch_num: 16 total batch number: 26 train loss 109.31792716426705 validate loss 86.27869188785553 test loss 54.53755147755146 running time for this epoch: 17.543041467666626 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 69 total epoch: 100 batch_num: 16 total batch number: 26 train loss 112.33437780497802 validate loss 86.21047250926495 test loss 54.57475781440735
running time for this epoch: 18.1369845867157 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 70 total epoch: 100 batch_num: 16 total batch number: 26 train loss 108.78722295509998 validate loss 88.2918480783701 test loss 53.79617214202881 running time for this epoch: 17.614197731018066 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 71 total epoch: 100 batch_num: 16 total batch number: 26 train loss 109.1310968277631 validate loss 90.2875792235136 test loss 53.3164432272315 running time for this epoch: 17.650712251663208 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 72 total epoch: 100 batch_num: 16 total batch number: 26 train loss 109.53005377410186 validate loss 90.51908738911152 test loss 53.91379352658987 running time for this epoch: 17.732501745224 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 73 total epoch: 100 batch_num: 16 total batch number: 26 train loss 112.11472084849245 validate loss 87.31313049793243 test loss 54.90529350936413 running time for this epoch: 17.586791515350342 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 74 total epoch: 100 batch_num: 16 total batch number: 26 train loss 112.2281975026622 validate loss 86.03234775364399 test loss 55.2273727953434 running time for this epoch: 17.54126811027527 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 75 total epoch: 100 batch_num: 16 total batch number: 26 train loss 108.47893733255289 validate loss 90.0151114910841 test loss 53.32834552973509 running time for this epoch: 18.51062846183777 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 76 total epoch: 100 batch_num: 16 total batch number: 26 train loss 107.85161474443696 validate loss 88.37387897074223 test loss 53.26088052242994 running time for this epoch: 17.585965394973755 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 77 total epoch: 100 batch_num: 16 total batch number: 26 train loss 118.41902147150702 validate loss 90.12818336486816 test loss 52.937837317585945 running time for this epoch: 17.537334442138672 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 78 total epoch: 100 batch_num: 16 total batch number: 26 train loss 107.55301022005301 validate loss 91.46795608103275 test loss 51.806652918457985 running time for this epoch: 17.427629947662354 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 79 total epoch: 100 batch_num: 16 total batch number: 26 train loss 112.92930597577381 validate loss 91.44378826022148 test loss 52.249254658818245
running time for this epoch: 17.181960105895996 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 80 total epoch: 100 batch_num: 16 total batch number: 26 train loss 107.7755118927194 validate loss 93.65084581077099 test loss 51.05310585349798 running time for this epoch: 16.90129566192627 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 81 total epoch: 100 batch_num: 16 total batch number: 26 train loss 107.55884047183726 validate loss 91.45685471594334 test loss 52.400310523808 running time for this epoch: 17.02692699432373 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 82 total epoch: 100 batch_num: 16 total batch number: 26 train loss 107.68594418617863 validate loss 87.89722807705402 test loss 54.313959553837776 running time for this epoch: 16.95948815345764 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 83 total epoch: 100 batch_num: 16 total batch number: 26 train loss 107.44844468448449 validate loss 91.93701669573784 test loss 51.63819994777441 running time for this epoch: 16.98026418685913 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 84 total epoch: 100 batch_num: 16 total batch number: 26 train loss 109.60523858321484 validate loss 93.07141415774822 test loss 51.46917421370745 running time for this epoch: 17.171870708465576 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 85 total epoch: 100 batch_num: 16 total batch number: 26 train loss 110.82052037602774 validate loss 90.46006947755814 test loss 52.72564012557268 running time for this epoch: 17.083730220794678 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 86 total epoch: 100 batch_num: 16 total batch number: 26 train loss 106.98091523307893 validate loss 90.76557122170925 test loss 53.06539125740528 running time for this epoch: 17.255159854888916 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 87 total epoch: 100 batch_num: 16 total batch number: 26 train loss 107.79120122875881 validate loss 91.18388406932354 test loss 51.74027755856514 running time for this epoch: 17.147231817245483 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 88 total epoch: 100 batch_num: 16 total batch number: 26 train loss 109.91087649017572 validate loss 93.19813922047615 test loss 51.72314587980509 running time for this epoch: 17.086824655532837 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 89 total epoch: 100 batch_num: 16 total batch number: 26 train loss 106.7098510382628 validate loss 89.1291256994009 test loss 52.240677177906036
running time for this epoch: 17.379044771194458 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 90 total epoch: 100 batch_num: 16 total batch number: 26 train loss 106.7019047870956 validate loss 91.76688268780708 test loss 52.046626806259155 running time for this epoch: 17.096171140670776 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 91 total epoch: 100 batch_num: 16 total batch number: 26 train loss 107.2757396226128 validate loss 91.29362180829048 test loss 51.72927398234606 running time for this epoch: 16.314340591430664 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 92 total epoch: 100 batch_num: 16 total batch number: 26 train loss 107.80534482802506 validate loss 92.22296066582203 test loss 51.38679873198271 running time for this epoch: 16.24642777442932 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 93 total epoch: 100 batch_num: 16 total batch number: 26 train loss 106.23334255931829 validate loss 90.42737074196339 test loss 51.86159163713455 running time for this epoch: 16.37497353553772 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 94 total epoch: 100 batch_num: 16 total batch number: 26 train loss 106.184746319842 validate loss 89.88216519355774 test loss 52.59794183075428 running time for this epoch: 16.310190439224243 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 95 total epoch: 100 batch_num: 16 total batch number: 26 train loss 107.09704400074702 validate loss 90.70559404790401 test loss 51.76751874387264 running time for this epoch: 16.32903265953064 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 96 total epoch: 100 batch_num: 16 total batch number: 26 train loss 106.30330549449556 validate loss 92.3113152384758 test loss 51.09037272632122 running time for this epoch: 16.358928442001343 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 97 total epoch: 100 batch_num: 16 total batch number: 26 train loss 105.78960926202987 validate loss 93.06204505264759 test loss 51.346332766115665 running time for this epoch: 16.229039192199707 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 98 total epoch: 100 batch_num: 16 total batch number: 26 train loss 106.42063730672278 validate loss 91.39150381088257 test loss 51.40568129718304 running time for this epoch: 16.193135023117065 --------------------------------------------------------------- --------------------------------------------------------------- current epoch: 99 total epoch: 100 batch_num: 16 total batch number: 26 train loss 106.13771780669965 validate loss 92.4329087138176 test loss 50.989785231649876
running time for this epoch: 16.497001886367798 --------------------------------------------------------------- --------------------------------------------------------------- ---------------------------finish model training-------------------------
#4.3
print (len(train_x_y))
print (len(validate_x_y))
print (len(test_x_y))
#4.3.1 model validation
validation_result, validate_hat, validate_real = validate_test_process(trained_model, validate_x_y)
print ("---------------------------------finish model validation------------------------------------")
print (len(validate_hat))
print (len(validate_real))
#4.3.2 model testing
#4.4. model test
test_result, test_hat, test_real = validate_test_process(trained_model, test_x_y)
print ("---------------------------------finish model testing------------------------------------")
print (len(test_real))
print (len(test_hat))
210 30 60 ---------------------------------finish model validation------------------------------------ 30 30 ---------------------------------finish model testing------------------------------------ 60 60
#5.1 RMSE, MAPE, MAE, RMSLE
def RMSELoss(yhat,y):
return float(torch.sqrt(torch.mean((yhat-y)**2)))
def MAPELoss(yhat,y):
return float(torch.mean(torch.div(torch.abs(yhat-y), y)))
def MAELoss(yhat,y):
return float(torch.mean(torch.div(torch.abs(yhat-y), 1)))
def RMSLELoss(yhat,y):
log_yhat = torch.log(yhat+1)
log_y = torch.log(y+1)
return float(torch.sqrt(torch.mean((log_yhat-log_y)**2)))
#compute RMSE
rmse_validate = list()
rmse_test = list()
for i in range(len(validate_x_y)):
rmse_validate.append(float(RMSELoss(validate_hat[i],validate_real[i])))
for i in range(len(test_x_y)):
rmse_test.append(float(RMSELoss(test_hat[i],test_real[i])))
print ("rmse_validate mean", np.mean(rmse_validate))
print ("rmse_test mean", np.mean(rmse_test))
#compute MAE
mae_validate = list()
mae_test = list()
for i in range(len(validate_x_y)):
mae_validate.append(float(MAELoss(validate_hat[i],validate_real[i])))
for i in range(len(test_x_y)):
mae_test.append(float(MAELoss(test_hat[i],test_real[i])))
print ("mae_validate mean", np.mean(mae_validate))
print ("mae_test mean", np.mean(mae_test))
#show RMSE and MAE together
mae_validate, rmse_validate, mae_test, rmse_test =\
np.array(mae_validate)*infection_normalize_ratio, np.array(rmse_validate)*infection_normalize_ratio,\
np.array(mae_test)*infection_normalize_ratio, np.array(rmse_test)*infection_normalize_ratio
print ("-----------------------------------------")
print ("mae_validate mean", round(np.mean(mae_validate),3), " rmse_validate mean", round(np.mean(rmse_validate),3))
print ("mae_test mean", round(np.mean(mae_test),3), " rmse_test mean", round(np.mean(rmse_test),3))
print ("-----------------------------------------")
rmse_validate mean 0.08250582211053252 rmse_test mean 0.06820357225972093 mae_validate mean 0.05825752305376434 mae_test mean 0.05250513498890561 ----------------------------------------- mae_validate mean 5.826 rmse_validate mean 8.251 mae_test mean 5.251 rmse_test mean 6.82 -----------------------------------------
print(validate_hat[0][Y_day-1])
print(torch.sum(validate_hat[0][Y_day-1]))
print(validate_real[0][Y_day-1])
print(torch.sum(validate_real[0][Y_day-1]))
tensor([0.0381, 0.1415, 0.3247, 0.5011, 0.1707, 0.2769, 0.1918, 0.3770, 0.2975,
0.3419, 0.6821, 0.8111, 0.3522, 0.4191, 0.4904, 0.2453, 0.2995, 0.1959,
0.4605, 0.5326, 0.6195, 0.3227, 0.5396], grad_fn=<SelectBackward>)
tensor(8.6315, grad_fn=<SumBackward0>)
tensor([0.0229, 0.0600, 0.0900, 0.2014, 0.0500, 0.0786, 0.1164, 0.1614, 0.0929,
0.1157, 0.4357, 0.4164, 0.1000, 0.1350, 0.1729, 0.2386, 0.1150, 0.0964,
0.2600, 0.2907, 0.2500, 0.3986, 0.3093], dtype=torch.float64)
tensor(4.2079, dtype=torch.float64)
print (trained_model.v)
Parameter containing:
tensor([1.6562e-01, 8.4237e-02, 4.3068e-02, 1.0835e-03, 1.4547e-01, 6.7214e-02,
6.3140e-02, 2.9542e-02, 4.3886e-02, 5.6127e-02, 4.7995e-04, 4.6931e-05,
4.7609e-02, 3.5730e-02, 9.3220e-03, 4.3771e-02, 5.7581e-02, 7.8378e-02,
5.6774e-03, 8.6750e-04, 6.5533e-05, 8.9712e-03, 8.9543e-04],
requires_grad=True)
x = range(len(rmse_validate))
plt.figure(figsize=(8,2),dpi=300)
l1 = plt.plot(x, np.array(rmse_validate), 'ro-',linewidth=0.8, markersize=1.2, label='RMSE')
l2 = plt.plot(x, np.array(mae_validate), 'go-',linewidth=0.8, markersize=1.2, label='MAE')
plt.xlabel('Date from the first day of validation',fontsize=12)
plt.ylabel("RMSE/MAE daily new cases",fontsize=10)
my_y_ticks = np.arange(0,2100, 500)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.legend()
plt.grid()
plt.show()
x = range(len(mae_test))
plt.figure(figsize=(8,2),dpi=300)
l1 = plt.plot(x, np.array(rmse_test), 'ro-',linewidth=0.8, markersize=1.2, label='RMSE')
l2 = plt.plot(x, np.array(mae_test), 'go-',linewidth=0.5, markersize=1.2, label='MAE')
plt.xlabel('Date from the first day of test',fontsize=12)
plt.ylabel("RMSE/MAE Daily new cases",fontsize=10)
my_y_ticks = np.arange(0,2100, 500)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.legend()
plt.grid()
plt.show()
from scipy import stats
#validate
y_days = Y_day
validate_hat_sum = [float(torch.sum(validate_hat[i][y_days-1])) for i in range(len(validate_hat))]
validate_real_sum = [float(torch.sum(validate_real[i][y_days-1])) for i in range(len(validate_real))]
print ("the correlation between validation: ", stats.pearsonr(validate_hat_sum, validate_real_sum)[0])
#test
test_hat_sum = [float(torch.sum(test_hat[i][y_days-1])) for i in range(len(test_hat))]
test_real_sum = [float(torch.sum(test_real[i][y_days-1])) for i in range(len(test_real))]
print ("the correlation between test: ", stats.pearsonr(test_hat_sum, test_real_sum)[0])
#train
train_result, train_hat, train_real = validate_test_process(trained_model, train_x_y)
train_hat_sum = [float(torch.sum(train_hat[i][0])) for i in range(len(train_hat))]
train_real_sum = [float(torch.sum(train_real[i][0])) for i in range(len(train_real))]
print ("the correlation between train: ", stats.pearsonr(train_hat_sum, train_real_sum)[0])
the correlation between validation: 0.679668723770587 the correlation between test: -0.0689995556723094 the correlation between train: 0.9721069427401245
y1List = [np.sum(list(train_original[i+1][1][Y_day-1].values())) for i in range(len(train_original)-1)]
y2List = [np.sum(list(validate_original[i][1][Y_day-1].values())) for i in range(len(validate_original))]
y2List_hat = [float(torch.sum(validate_hat[i][Y_day-1])) for i in range(len(validate_hat))]
y3List = [np.sum(list(test_original[i][1][Y_day-1].values())) for i in range(len(test_original))]
y3List_hat = [float(torch.sum(test_hat[i][Y_day-1])) for i in range(len(test_hat))]
#x1 = np.array(range(len(y1List)))
#x2 = np.array([len(y1List)+j for j in range(len(y2List))])
x1 = train_list
x2 = validation_list
x3 = np.array([len(y1List)+len(y2List)+j for j in range(len(y3List))])
plt.figure(figsize=(8,2),dpi=300)
l1 = plt.plot(x1[0: len(y1List)], np.array(y1List)*infection_normalize_ratio, 'ro-',linewidth=0.8, markersize=2.0, label='train')
l2 = plt.plot(x2, np.array(y2List)*infection_normalize_ratio, 'go-',linewidth=0.8, markersize=2.0, label='validate')
l3 = plt.plot(x2, np.array(y2List_hat)*infection_normalize_ratio, 'g-',linewidth=2, markersize=0.1, label='validate_predict')
l4 = plt.plot(x3, np.array(y3List)*infection_normalize_ratio, 'bo-',linewidth=0.8, markersize=2, label='test')
l5 = plt.plot(x3, np.array(y3List_hat)*infection_normalize_ratio, 'b-',linewidth=2, markersize=0.1, label='test_predict')
#plt.xlabel('Date from the first day of 2020/4/1',fontsize=12)
plt.ylabel("Daily infection cases",fontsize=10)
my_y_ticks = np.arange(0,2100, 500)
my_x_ticks = list()
summary = 0
my_x_ticks.append(summary)
for i in range(5):
summary += 60
my_x_ticks.append(summary)
plt.xticks(my_x_ticks)
plt.yticks(my_y_ticks)
plt.xticks(fontsize=8)
plt.yticks(fontsize=12)
plt.title("SpectralGCN")
plt.legend()
plt.grid()
#plt.savefig('sg_peak4_21_21_1feature_0005.pdf',bbox_inches = 'tight')
plt.show()
def getPredictionPlot(k):
#location k
x_k = [i for i in range(len(test_real))]
real_k = [test_real[i][y_days-1][k] for i in range(len(test_real))]
predict_k = [test_hat[i][y_days-1][k] for i in range(len(test_hat))]
plt.figure(figsize=(4,2.5), dpi=300)
l1 = plt.plot(x_k, np.array(real_k)*infection_normalize_ratio, 'ro-',linewidth=0.8, markersize=2.0, label='real',alpha = 0.8)
l2 = plt.plot(x_k, np.array(predict_k)*infection_normalize_ratio, 'o-',color='black',linewidth=0.8, markersize=2.0, alpha = 0.8, label='predict')
#plt.xlabel('Date from the first day of 2020/4/1',fontsize=12)
#plt.ylabel("Daily infection cases",fontsize=10)
my_y_ticks = np.arange(0,100,40)
my_x_ticks = list()
summary = 0
my_x_ticks.append(summary)
for i in range(6):
summary += 10
my_x_ticks.append(summary)
plt.xticks(my_x_ticks)
plt.yticks(my_y_ticks)
plt.xticks(fontsize = 14)
plt.yticks(fontsize = 14)
plt.title("Real and predict daily infection for region "+str(k))
plt.legend()
plt.grid()
#plt.savefig('sg_peak4_21_21_1feature_0005.pdf',bbox_inches = 'tight')
plt.show()
for i in range(23):
getPredictionPlot(i)